from typing import List


def numberOfRightTriangles(grid: List[List[int]]) -> int:
    """Count right triangles whose three vertices are cells equal to 1.

    A right triangle is a triple of 1-cells where one cell (the right-angle
    vertex) shares its row with a second cell and its column with the third.

    Args:
        grid: Rectangular 0/1 matrix. An empty grid (no rows, or a first row
            with no columns) yields 0.

    Returns:
        The number of such triangles.
    """
    if not grid or not grid[0]:
        return 0
    # Number of 1s in each column; zip(*grid) transposes the rows.
    col_counts = [sum(column) for column in zip(*grid)]
    total = 0
    for row in grid:
        row_count = sum(row)
        for cell, col_count in zip(row, col_counts):
            if cell == 1:
                # With this cell as the right-angle vertex, any *other* 1 in
                # its row can pair with any *other* 1 in its column.
                total += (row_count - 1) * (col_count - 1)
    return total
